# -*- coding: utf-8 -*-

from keras.models import load_model
import tensorflow as tf
import esc10_input
import numpy as np
import os

import matplotlib.pyplot as plt


# Global matplotlib text settings: SimHei is selected as the sans-serif font so
# the Chinese class names used as tick labels can be rendered, and the Unicode
# minus sign is disabled (presumably because that font lacks the glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})



def use_gpu():
    """
    Pin the process to GPU 0 and hand Keras a TensorFlow session for it.

    The session is configured with a per-process GPU memory fraction of 0.5
    and with ``allow_growth`` enabled, then registered through
    ``keras.backend.set_session``.
    """
    # NOTE(review): a previously commented-out import suggests that some Keras
    # versions expose set_session under keras.backend.tensorflow_backend instead.
    from keras.backend import set_session

    # Only the first CUDA device is made visible to TensorFlow.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.5
    session_config.gpu_options.allow_growth = True

    gpu_session = tf.InteractiveSession(config=session_config)
    set_session(gpu_session)


from sklearn.metrics import confusion_matrix

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import itertools

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=plt.cm.Greens,  # colour theme of the matrix; this one looks clean
                          normalize=True):
    """
    Draw a confusion matrix as an annotated heat map and display it.

    :param cm: square confusion matrix (array-like); rows are drawn along the
        'True label' axis and columns along the 'Predicted label' axis
    :param target_names: class names used as tick labels on both axes, or
        None to keep matplotlib's default ticks
    :param title: title of the figure
    :param cmap: matplotlib colormap; None falls back to 'Blues'
    :param normalize: if True, every row is divided by its row sum and cells
        are annotated with 4-decimal fractions; otherwise the raw values are
        shown
    :return: None; the figure is shown with ``plt.show()``
    """
    cm = np.asarray(cm)

    # Accuracy / misclassification are always derived from the raw matrix,
    # before any row normalisation is applied.
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        # Normalise BEFORE drawing so that the cell colours, the cell text and
        # the contrast threshold below all refer to the same values.
        # Rows without any sample are left at 0 instead of producing NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(10, 15))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45, fontsize=15)
        plt.yticks(tick_marks, target_names, fontsize=15)

    # Cells above the threshold have a dark background and get white text.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_format = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_format.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=15)

    plt.ylabel('True label', fontsize=20)
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass), fontsize=20)
    # Run the layout pass only after the axis labels exist so they are taken
    # into account and not clipped.
    plt.tight_layout()
    # To save the figure, uncomment the line below, adjust the path and set
    # the resolution through dpi.
    # plt.savefig('./picture/confusionmatrix32.png',dpi=300)
    plt.show()

def CNN_test(test_fold, feat):
    """
    Test models using test set

    Loads the saved model for the given fold and feature, plots the confusion
    matrix of its predictions on the test split and prints the evaluation
    scores.

    :param test_fold: test fold of 5-fold cross validation
    :param feat: which feature to use
    :return: ``score[1]`` of ``model.evaluate`` (printed as 'Test accuracy';
        assumes accuracy is the model's first compiled metric)
    """
    # Read the test data; the first two returned values are not used here.
    _, _, test_features, test_labels = esc10_input.get_data(test_fold, feat)

    print("test_features: {}, test_labels: {}".format(len(test_features), len(test_labels)))
    # NOTE(review): an earlier debug print reported test_features.shape as
    # (80, 60, 65, 1) -- confirm against esc10_input.

    # Load the trained model for this feature / fold.
    model = load_model('./saved_model/cnn_{}_fold{}.h5'.format(feat, test_fold))

    # Class names shown on both axes of the confusion matrix; position in the
    # list is the class index.
    labels = ['狗', '公鸡', '雨', '海浪', '火', '婴儿哭', '喷嚏', '时钟滴答', '直升飞机', '电锯']

    predictions = model.predict(test_features, batch_size=32)
    predictions = np.argmax(predictions, axis=-1)
    truelabel = test_labels.argmax(axis=-1)  # one-hot vectors -> class indices
    # Pin the class indices so the matrix always has one row/column per entry
    # of `labels`, even when a class never occurs in this fold; otherwise the
    # tick names would be attached to the wrong rows/columns.
    conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions,
                                labels=list(range(len(labels))))
    plot_confusion_matrix(conf_mat, normalize=True, target_names=labels, title='Confusion Matrix')

    # Report how the trained model performs on the test set.
    score = model.evaluate(test_features, test_labels)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])

    return score[1]

# Show the confusion matrix
# def plot_confuse(model, x_val, y_val):
#     predictions = model.predict_classes(x_val, batch_size=32)
#     truelabel = y_val.argmax(axis=-1)  # convert one-hot vectors to class labels
#     conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
#     plt.figure()
#     plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')


# =========================================================================================
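# Finally, just call this function. test_x is the test data and test_y the test labels (one-hot vectors here).
# labels is a list holding the names of your classes; they are shown on the horizontal and vertical axes.
# For example, here is my labels list: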
# labels = ['StandingUpFS', 'StandingupFL', 'Walking', 'Running', 'GoingUpS', 'Jumping', 'GoingdownS', 'LyingDownS',
#           'SittingDown',
#           'Falling Forw',
#           'Falling right', 'FallingBack', 'HittingObstacle', 'Falling with ps', 'FallingBackSC', 'Syncope',
#           'falling left']

# plot_confuse(model, test_x, test_y,labels)


if __name__ == '__main__':
    # Evaluate the saved log-mel models of all five folds and report the
    # per-fold accuracies together with their mean.
    # To evaluate the MFCC models instead, pass 'mfcc' to CNN_test below.
    use_gpu()  # use the GPU

    print('### [Start] Test models for ESC10 dataset #####')
    fold_accuracy = {}
    for fold_id in range(1, 6):
        print("## Start test fold{} models #####".format(fold_id))
        fold_accuracy['fold{}'.format(fold_id)] = CNN_test(fold_id, 'logmel')
        print("## Finish test fold{} models #####".format(fold_id))

    # The mean is taken over the five fold entries before it is stored itself.
    per_fold_values = list(fold_accuracy.values())
    fold_accuracy['mean'] = np.mean(per_fold_values)
    print(fold_accuracy)
    print('### [Finish] Test models finished for ESC10 dataset #####')
